// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/workthreadpool.h>

#include <zencore/blockingqueue.h>
#include <zencore/except.h>
#include <zencore/logging.h>
#include <zencore/string.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
#include <zencore/trace.h>

#include <thread>
#include <vector>

#define ZEN_USE_WINDOWS_THREADPOOL 1

#if ZEN_PLATFORM_WINDOWS && ZEN_USE_WINDOWS_THREADPOOL
#	include <zencore/windows.h>
#endif

namespace zen {

namespace detail {
	/// Adapts a callable into an IWork item so that WorkerThreadPool::ScheduleWork(std::function<void()>&&)
	/// can queue it like any other work item. Execute() simply invokes the stored callable.
	struct LambdaWork final : IWork
	{
		// The callable is taken by value and moved into storage, so callers that pass an
		// rvalue avoid a second copy of the (possibly large) captured state.
		explicit LambdaWork(auto Work) : WorkFunction(std::move(Work)) {}
		virtual void Execute() override { WorkFunction(); }

		std::function<void()> WorkFunction;
	};
}  // namespace detail

//////////////////////////////////////////////////////////////////////////

#if ZEN_USE_WINDOWS_THREADPOOL && ZEN_PLATFORM_WINDOWS

namespace {
	// Per-OS-thread flag: WorkerThreadPool::Impl::DoWork sets it the first time a pool thread runs
	// a callback, after giving the thread its "<base>_<index>" name, so naming happens only once.
	thread_local bool t_IsThreadNamed = false;
}  // namespace

struct WorkerThreadPool::Impl
{
	PTP_POOL			m_ThreadPool   = nullptr;
	PTP_CLEANUP_GROUP	m_CleanupGroup = nullptr;
	TP_CALLBACK_ENVIRON m_CallbackEnvironment;
	PTP_WORK			m_Work = nullptr;  // member of m_CleanupGroup; released together with it

	std::string		 m_WorkerThreadBaseName;
	std::atomic<int> m_WorkerThreadCounter{0};

	// Items waiting for a thread pool callback. ScheduleWork pushes exactly one item per
	// SubmitThreadpoolWork call and each callback pops one item. The lock is mutable so the
	// const PendingWorkItemCount() can take it.
	mutable RwLock		   m_QueueLock;
	std::deque<Ref<IWork>> m_WorkQueue;

	/// Creates a private Win32 thread pool (InThreadCount minimum, 2 * InThreadCount maximum
	/// threads) plus one work object whose callbacks drain m_WorkQueue.
	/// Throws via ThrowLastError if a thread pool object cannot be created; anything created
	/// before the failure is released before the exception propagates.
	Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName)
	{
		// Initialized first so ReleaseThreadpoolObjects() can unconditionally destroy it.
		InitializeThreadpoolEnvironment(&m_CallbackEnvironment);

		try
		{
			// Thread pool setup

			m_ThreadPool = CreateThreadpool(NULL);
			if (m_ThreadPool == NULL)
			{
				ThrowLastError("CreateThreadpool failed");
			}

			if (!SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount))
			{
				ThrowLastError("SetThreadpoolThreadMinimum failed");
			}
			SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2);

			m_CleanupGroup = CreateThreadpoolCleanupGroup();
			if (m_CleanupGroup == NULL)
			{
				ThrowLastError("CreateThreadpoolCleanupGroup failed");
			}

			SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool);
			SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL);

			m_Work = CreateThreadpoolWork(&WorkCallback, this, &m_CallbackEnvironment);
			if (m_Work == NULL)
			{
				ThrowLastError("CreateThreadpoolWork failed");
			}
		}
		catch (...)
		{
			// The destructor does not run for a partially constructed object, so release here.
			ReleaseThreadpoolObjects();
			throw;
		}
	}

	~Impl() { ReleaseThreadpoolObjects(); }

	/// Tears down every thread pool object that was created. Callbacks that have not started yet
	/// are cancelled (their items stay in m_WorkQueue and are destroyed with it); callbacks that
	/// are already running are waited for.
	void ReleaseThreadpoolObjects()
	{
		if (m_CleanupGroup != nullptr)
		{
			// Waits for running callbacks, cancels pending ones and closes every member of the
			// group (m_Work), then closes the group itself, which was previously leaked.
			CloseThreadpoolCleanupGroupMembers(m_CleanupGroup, /* CancelPendingCallbacks */ TRUE, nullptr);
			CloseThreadpoolCleanupGroup(m_CleanupGroup);
			m_CleanupGroup = nullptr;
			m_Work		   = nullptr;
		}

		DestroyThreadpoolEnvironment(&m_CallbackEnvironment);

		if (m_ThreadPool != nullptr)
		{
			CloseThreadpool(m_ThreadPool);
			m_ThreadPool = nullptr;
		}
	}

	void ScheduleWork(Ref<IWork> Work)
	{
		m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); });
		SubmitThreadpoolWork(m_Work);
	}

	/// Number of scheduled items that no callback has picked up yet.
	[[nodiscard]] size_t PendingWorkItemCount() const
	{
		RwLock::ExclusiveLockScope _{m_QueueLock};
		return m_WorkQueue.size();
	}

	static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work)
	{
		ZEN_UNUSED(Instance, Work);
		Impl* ThisPtr = reinterpret_cast<Impl*>(Context);
		ThisPtr->DoWork();
	}

	/// Runs one queued work item on the calling pool thread. Exceptions thrown by the item are
	/// stored in IWork::m_Exception and logged instead of escaping into the Win32 thread pool,
	/// matching the portable std::thread implementation.
	void DoWork()
	{
		if (!t_IsThreadNamed)
		{
			t_IsThreadNamed								  = true;
			const int						  ThreadIndex = ++m_WorkerThreadCounter;
			zen::ExtendableStringBuilder<128> ThreadName;
			ThreadName << m_WorkerThreadBaseName << "_" << ThreadIndex;
			SetCurrentThreadName(ThreadName);
		}

		Ref<IWork> WorkFromQueue;

		{
			RwLock::ExclusiveLockScope _{m_QueueLock};
			// Defensive: submissions and queue pushes are paired, but never call front() on an
			// empty deque (undefined behavior).
			if (m_WorkQueue.empty())
			{
				return;
			}
			WorkFromQueue = std::move(m_WorkQueue.front());
			m_WorkQueue.pop_front();
		}

		try
		{
			ZEN_TRACE_CPU_FLUSH("AsyncWork");
			WorkFromQueue->Execute();
		}
		catch (const AssertException& Ex)
		{
			WorkFromQueue->m_Exception = std::current_exception();

			ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription());
		}
		catch (const std::exception& e)
		{
			WorkFromQueue->m_Exception = std::current_exception();

			ZEN_WARN("Caught exception in worker thread: {}", e.what());
		}
		catch (...)
		{
			WorkFromQueue->m_Exception = std::current_exception();

			ZEN_WARN("Caught unknown exception in worker thread");
		}
	}
};

#else

/// Arguments handed to each portable worker thread at creation.
struct WorkerThreadPool::ThreadStartInfo
{
	int			ThreadNumber{0};	// 1-based index, used as the thread name suffix
	zen::Latch* Latch{nullptr};		// startup latch owned by the Impl constructor
};

/// Portable pool: a fixed set of std::threads that pull work from a shared blocking queue.
struct WorkerThreadPool::Impl
{
	// Body of every worker thread; defined out of line below.
	void					  WorkerThreadFunction(ThreadStartInfo Info);
	std::string				  m_WorkerThreadBaseName;
	std::vector<std::thread>  m_WorkerThreads;
	BlockingQueue<Ref<IWork>> m_WorkQueue;

	/// Starts InThreadCount workers and returns only after every one of them has counted down
	/// the startup latch.
	Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName)
	{
#	if ZEN_WITH_TRACE
		trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str());
#	endif

		zen::Latch StartupLatch{InThreadCount};

		for (int ThreadNumber = 1; ThreadNumber <= InThreadCount; ++ThreadNumber)
		{
			const ThreadStartInfo StartInfo{ThreadNumber, &StartupLatch};
			m_WorkerThreads.emplace_back(&Impl::WorkerThreadFunction, this, StartInfo);
		}

		// The latch lives on this stack frame; the workers stop touching it after CountDown().
		StartupLatch.Wait();

#	if ZEN_WITH_TRACE
		trace::ThreadGroupEnd();
#	endif
	}

	/// Signals that no more work will be added, then joins every worker. Workers leave their
	/// loop once WaitAndDequeue reports failure.
	~Impl()
	{
		m_WorkQueue.CompleteAdding();

		for (std::thread& WorkerThread : m_WorkerThreads)
		{
			if (!WorkerThread.joinable())
			{
				continue;
			}
			WorkerThread.join();
		}

		m_WorkerThreads.clear();
	}

	void ScheduleWork(Ref<IWork> Work)
	{
		m_WorkQueue.Enqueue(std::move(Work));
	}

	[[nodiscard]] size_t PendingWorkItemCount() const
	{
		return m_WorkQueue.Size();
	}
};

// Body of each portable worker thread. It names the thread and reports startup through the
// latch, then executes work items until WaitAndDequeue reports that the queue is finished.
// Exceptions thrown by a work item are stored in IWork::m_Exception and logged. They are never
// allowed to escape, because an exception leaving a std::thread function calls std::terminate.
void
WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info)
{
	SetCurrentThreadName(fmt::format("{}_{}", m_WorkerThreadBaseName, Info.ThreadNumber));

	// The latch lives on the Impl constructor's stack and is destroyed once that constructor's
	// Wait() returns, so it must not be used after this call.
	Info.Latch->CountDown();

	for (;;)
	{
		// Declared per iteration so a finished work item is released before blocking again.
		Ref<IWork> Work;
		if (!m_WorkQueue.WaitAndDequeue(Work))
		{
			return;
		}

		try
		{
			ZEN_TRACE_CPU_FLUSH("AsyncWork");
			Work->Execute();
		}
		catch (const AssertException& Ex)
		{
			Work->m_Exception = std::current_exception();

			ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription());
		}
		catch (const std::exception& e)
		{
			Work->m_Exception = std::current_exception();

			ZEN_WARN("Caught exception in worker thread: {}", e.what());
		}
		catch (...)
		{
			// Non-std exceptions would otherwise terminate the whole process.
			Work->m_Exception = std::current_exception();

			ZEN_WARN("Caught unknown exception in worker thread");
		}
	}
}

#endif

//////////////////////////////////////////////////////////////////////////
WorkerThreadPool::WorkerThreadPool(int InThreadCount) : WorkerThreadPool(InThreadCount, "workerthread")
{
}

WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName)
{
	// A non-positive thread count leaves m_Impl empty; ScheduleWork then runs every work item
	// synchronously on the calling thread.
	if (InThreadCount <= 0)
	{
		return;
	}

	m_Impl = std::make_unique<Impl>(InThreadCount, WorkerThreadBaseName);
}

WorkerThreadPool::~WorkerThreadPool()
{
	// Tear the backend down explicitly; Impl's destructor shuts down its worker threads.
	m_Impl = nullptr;
}

/// Queues Work on the pool. When the pool was created without threads (no Impl), the item is
/// executed immediately on the calling thread. Any exception it throws is stored in
/// Work->m_Exception and logged instead of propagating to the caller, which mirrors the
/// threaded path.
void
WorkerThreadPool::ScheduleWork(Ref<IWork> Work)
{
	if (m_Impl)
	{
		m_Impl->ScheduleWork(std::move(Work));
		return;
	}

	try
	{
		ZEN_TRACE_CPU_FLUSH("SyncWork");
		Work->Execute();
	}
	catch (const AssertException& Ex)
	{
		Work->m_Exception = std::current_exception();

		ZEN_WARN("Assert exception when executing worker synchronously: {}", Ex.FullDescription());
	}
	catch (const std::exception& e)
	{
		Work->m_Exception = std::current_exception();

		ZEN_WARN("Caught exception when executing worker synchronously: {}", e.what());
	}
	catch (...)
	{
		Work->m_Exception = std::current_exception();

		ZEN_WARN("Caught unknown exception when executing worker synchronously");
	}
}

/// Wraps a callable in a work item and schedules it. Work is an rvalue reference, so the callable
/// is moved into the wrapper rather than copied.
void
WorkerThreadPool::ScheduleWork(std::function<void()>&& Work)
{
	ScheduleWork(Ref<IWork>(new detail::LambdaWork(std::move(Work))));
}

[[nodiscard]] size_t
WorkerThreadPool::PendingWorkItemCount() const
{
	// Without an Impl every item runs inline inside ScheduleWork, so nothing is ever pending.
	return m_Impl ? m_Impl->PendingWorkItemCount() : size_t{0};
}

//////////////////////////////////////////////////////////////////////////

#if ZEN_WITH_TESTS

void
workthreadpool_forcelink()
{
	// Intentionally empty. NOTE(review): presumably called from the test runner so the linker
	// keeps this translation unit and its TEST_CASEs; confirm.
}

using namespace std::literals;

TEST_CASE("threadpool.basic")
{
	// A single worker handles two value-returning tasks and one task that throws.
	WorkerThreadPool Pool{1};

	auto AnswerFuture  = Pool.EnqueueTask(std::packaged_task<int()>{[]() -> int { return 42; }});
	auto OtherFuture   = Pool.EnqueueTask(std::packaged_task<int()>{[]() -> int { return 99; }});
	auto FailingFuture = Pool.EnqueueTask(std::packaged_task<void()>{[]() { throw std::runtime_error("meep!"); }});

	// The thrown exception must surface through the future rather than kill the worker.
	CHECK_EQ(AnswerFuture.get(), 42);
	CHECK_EQ(OtherFuture.get(), 99);
	CHECK_THROWS(FailingFuture.get());
}

#endif

}  // namespace zen
